热门标签 | HotTags
当前位置:  开发笔记 > 人工智能 > 正文

【深度学习】|线性回归的简单实现

1概述本文的主要目的是通过实现最简单的线性回归模型,理解pytorch在数据导入、模型定义、、损失计算、优化迭代、自动求导和批次训练等方面的特点。2数据导入首先,生成真实的线性函数

1 概述

本文的主要目的是通过实现最简单的线性回归模型,理解pytorch在数据导入、模型定义、、损失计算、优化迭代、自动求导和批次训练等方面的特点。


2 数据导入

首先,生成真实的线性函数,参数为w和b;接着按照w和b的size来生成1000个样本数据



点击查看代码

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l
true_w = torch.tensor([2,-3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b,1000)

构造出数据的Dataset类,将其放入到DataLoader中方便进行批次训练,DataLoader可以实现对Dataset中的数据进行shuffle和批次大小的划分。



点击查看代码

def load_array(data_arrays, batch_size, is_train = True):
'''构造一个PyTorch数据迭代器'''
dataset = data.TensorDataset(*data_arrays)# 此处*的作用
return data.DataLoader(dataset, batch_size, shuffle = is_train)
batch_size = 10
data_iter = load_array((features, labels), batch_size)
next(iter(data_iter))


3 模型定义

使用框架预定义好的层,nn是神经网络的缩写



点击查看代码

from torch import nn
net = nn.Sequential(nn.Linear(2, 1))

初始化模型参数



点击查看代码

net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)


4 定义损失函数



点击查看代码

loss = nn.MSELoss()


5 定义优化算法



点击查看代码

trainer = torch.optim.SGD(net.parameters(),lr = 0.03)


6 训练

使用随机小批量梯度下降法

训练过程中打印的loss是为研究者观察模型是否往参数逐步优化的方向变化而给的一个参考指标。

data_iter中的loss是指一个批次的损失。

trainer.zero_grad()是梯度清零

l.backward()是求解这一个批次的样本的导数和

trainer.step()# 以求得的导数和,结合优化器,更新参数 w和b,然后进行下一批次的训练



点击查看代码

# 训练
num_epochs = 3
for epoch in range(num_epochs):
for X, y in data_iter:
# print(X.shape)
l = loss(net(X), y)# 一个批次的损失
trainer.zero_grad()
l.backward()
trainer.step()
# print(features.shape)
l = loss(net(features), labels)# 整个数据集的损失
print(f'epoch {epoch + 1}, loss {l:f}')


7 打印结果



点击查看代码

w = net[0].weight.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net[0].bias.data
print('b的估计误差:', true_b - b)

image



推荐阅读
author-avatar
sherklock
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有